from .transformer import TransformerEncoder
from .mlp import ResMLPEncoder, MLPEncoder
from torch import nn

def setup_encoder(model_type, num_hidden, num_layers, dropout, activation, norm, nhead) -> nn.Module:
    """Build the encoder module selected by ``model_type``.

    Supported values:
      * ``"performer"``, ``"cosformer"``, ``"transformer"`` -> ``TransformerEncoder``
        (``model_type`` itself is forwarded so the encoder can pick its variant).
      * ``"mlp"`` -> ``MLPEncoder``
      * ``"resmlp"`` -> ``ResMLPEncoder``

    ``num_hidden``, ``num_layers``, ``dropout`` and ``norm`` are forwarded to
    every encoder. ``activation`` and ``nhead`` are forwarded only to the
    transformer family; the MLP encoders do not receive them.

    Raises:
        NotImplementedError: if ``model_type`` is not one of the values above.
    """
    if model_type in ("performer", "cosformer", "transformer"):
        return TransformerEncoder(
            num_hidden=num_hidden,
            nhead=nhead,
            num_layers=num_layers,
            dropout=dropout,
            activation=activation,
            norm=norm,
            model_type=model_type,
        )

    if model_type in ("mlp", "resmlp"):
        # Both MLP variants share one constructor signature; only the class differs.
        encoder_cls = MLPEncoder if model_type == "mlp" else ResMLPEncoder
        return encoder_cls(
            num_hidden=num_hidden,
            num_layers=num_layers,
            dropout=dropout,
            norm=norm,
        )

    raise NotImplementedError(f'Unsupported model type: {model_type}')